#@title Self-guidance modules
import torch
import diffusers
import torchvision.transforms.functional as TF
import numpy as np
from IPython.display import Image
from PIL import Image
import os

from utils.utils import view_images, AttentionStore, save_obj_cross_attention, save_mask_as_img, \
                  concat_images, aggregate_attention, Phrase2idx
from utils.inversion import NullInversion, load_512
from utils.SelfGuidanceEdits import SelfGuidanceEdits
from utils.loss import remd_loss_m2

from globals import *
from AttnProcessor import SelfGuidanceAttnProcessor, SelfGuidanceAttnProcessor2_0
from pipeline import SelfGuidanceSD_v1_5_Pipeline


# Seed the shared `Globals` registry (not defined in this file; presumably provided by
# `from globals import *`) with its start-up values. Later code overwrites several of
# these ('save_aux', 'new_timestep', 'DEBUG', 'step', ...) before each pipeline run.
_INITIAL_GLOBALS = (
    ('NO_OPT', True),
    ('DEBUG', False),
    ('save_aux', False),
    ('new_timestep', True),
    ('step', 0),
    ('hard_mask_base', None),
    ('prompt_text_ids', None),
    ('attn_path', ''),
)
for _key, _value in _INITIAL_GLOBALS:
  Globals.set(_key, _value)

# Loading options shared by both pipelines: fp32 weights from safetensors, no ONNX.
_PIPE_LOAD_KWARGS = dict(use_safetensors=True, torch_dtype=torch.float32, variant="fp32", use_onnx=False)

# `base` (from `args.pm_pretrained_model_name_or_path`) is the pipeline that is self-guided
# in generate_img(); `sd_pipe` (from `args.sd_pretrained_model_name_or_path`) produces the
# template images. `args` is not defined in this file (presumably from `globals`).
base: SelfGuidanceSD_v1_5_Pipeline = SelfGuidanceSD_v1_5_Pipeline.from_pretrained(
    args.pm_pretrained_model_name_or_path, **_PIPE_LOAD_KWARGS
)
base.to("cuda")

sd_pipe: SelfGuidanceSD_v1_5_Pipeline = SelfGuidanceSD_v1_5_Pipeline.from_pretrained(
    args.sd_pretrained_model_name_or_path, **_PIPE_LOAD_KWARGS
)
sd_pipe.to("cuda")

# Personalization-method specific setup: only textual inversion needs extra loading here;
# the DreamBooth and BLIP-Diffusion branches are no-ops.
if args.dreambooth:
  pass
elif args.textual_inversion:
  # Load Textual Inversion v1.4 embeddings into `base`.
  # NOTE: the folder and weight name below are placeholders to be filled in before use.
  base.load_textual_inversion("your-textual-inversion-folder", weight_name="your-textual-inversion-weight")
elif args.blip_diffusion:
  pass

# UNet block types whose cross-attention ("attn2") layers get self-guidance processors.
_CROSS_ATTN_BLOCK_TYPES = (
    diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D,
    diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D,
    diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn,
)


def _install_self_guidance_processors(pipe, store):
  """Swap the processor of every cross-attention layer of `pipe.unet` for a self-guidance one.

  Only `Attention` modules whose name contains 'attn2' inside one of the block types in
  `_CROSS_ATTN_BLOCK_TYPES` are touched. An `AttnProcessor2_0` is replaced by
  `SelfGuidanceAttnProcessor2_0` and a classic `AttnProcessor` by `SelfGuidanceAttnProcessor`.
  Each new processor receives the layer's dotted name, `store` and `pipe`.

  Args:
    pipe: pipeline whose UNet is patched in place.
    store: attention store handed to every new processor.

  Raises:
    NotImplementedError: if a matching layer uses any other processor type.
  """
  attention_cls = diffusers.models.attention_processor.Attention
  for name, block in pipe.unet.named_modules():
    if not isinstance(block, _CROSS_ATTN_BLOCK_TYPES):
      continue
    for attn_name, attn in block.named_modules():
      if 'attn2' not in attn_name or not isinstance(attn, attention_cls):
        continue
      full_name = name + '.' + attn_name
      processor = attn.processor
      # AttnProcessor2_0 is checked first, as in the original per-pipeline loops.
      if isinstance(processor, diffusers.models.attention_processor.AttnProcessor2_0):
        attn.processor = SelfGuidanceAttnProcessor2_0(full_name, store, pipe)
        print(f'{full_name} change to SelfGuidanceAttnProcessor2_0')
      elif isinstance(processor, diffusers.models.attention_processor.AttnProcessor):
        attn.processor = SelfGuidanceAttnProcessor(full_name, store, pipe)
        print(f'{full_name} change to SelfGuidanceAttnProcessor')
      else:
        raise NotImplementedError("Self-guidance is not implemented for this attention processor")


# Both pipelines record into the same AttentionStore. It is reset between the two
# installations and again before every later generation.
attn_store = AttentionStore()
_install_self_guidance_processors(base, attn_store)
base.to('cuda')

attn_store.reset()
_install_self_guidance_processors(sd_pipe, attn_store)
sd_pipe.to('cuda')

# Optionally replace both pipelines' schedulers according to `use_scheduler` (not defined in
# this file; presumably provided by `from globals import *`). Any other value keeps the
# schedulers that were loaded with the checkpoints.
if use_scheduler == 'ddpm':
  print('Using DDPM as scheduler.')
  for _pipe in (base, sd_pipe):
    # Rebuilt from each pipeline's own scheduler config.
    _pipe.scheduler = diffusers.DDPMScheduler.from_config(_pipe.scheduler.config)
elif use_scheduler == 'ddim':
  for _pipe in (base, sd_pipe):
    # Built from explicit beta settings rather than from_config, so every other
    # DDIMScheduler option takes the library default.
    _pipe.scheduler = diffusers.DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)

def resize(x):
  """Antialiased torchvision resize of `x` to the module-level size `SG_RES`.

  `SG_RES` is not defined in this file (presumably provided by `from globals import *`).
  """
  target_size = SG_RES
  return TF.resize(x, target_size, antialias=True)

def save_mask(base, prompt, token, attn_path, thresh=0.1, res=None):
  """Save a binarized cross-attention mask of `token` for every debug step and for step -1.

  For each step in the module-level `debug_steps` followed by -1, the "up"-block
  cross-attention recorded in the module-level `attn_store` is passed through
  `save_obj_cross_attention` (with `outpath=attn_path`). The channel of the first index
  returned by `Phrase2idx(prompt, token)` is binarized with
  `SelfGuidanceEdits._attn_diff_norm(..., hard=True, thresh=thresh)`. The result is written
  to `<step>_<token>_[<index>]_<res>_hard_mask.png` inside `attn_path`.

  Args:
    base: pipeline whose tokenizer is used for `prompt`.
    prompt: prompt the attention was recorded for.
    token: word/phrase whose attention map is saved.
    attn_path: output directory.
    thresh: threshold forwarded to `_attn_diff_norm`.
    res: attention resolutions to aggregate; defaults to [16, 32].
  """
  # Avoid a shared mutable default; the list form keeps the '[16, 32]' file-name suffix.
  if res is None:
    res = [16, 32]
  obj_indices = Phrase2idx(prompt, token)
  obj_indice = obj_indices[0][0]
  for debug_step in debug_steps + [-1]:
    _, attention_maps = save_obj_cross_attention(
        [prompt], obj_indices, attn_store, base.tokenizer,
        res=res, from_where=["up"], step=debug_step, outpath=attn_path)
    soft_mask_s = attention_maps[..., obj_indice]
    hard_mask_s = SelfGuidanceEdits._attn_diff_norm(soft_mask_s, hard=True, thresh=thresh)
    # Name the file after `token` itself, not a global word list that may hold a different word.
    save_mask_as_img(hard_mask_s, os.path.join(attn_path, f'{debug_step}_{token}_[{obj_indice}]_{res}_hard_mask.png'))

def _timestep_window(timestep, num_steps=50):
  """Convert get_mask's `timestep` argument into a half-open, non-empty (start, end) range."""
  if isinstance(timestep, (tuple, list)):
    start_t, end_t = timestep
  elif timestep == -1:
    start_t, end_t = 0, num_steps
  elif isinstance(timestep, int) and not isinstance(timestep, bool):
    start_t, end_t = timestep, timestep + 1
  else:
    raise TypeError('wrong timestep form! should be int, -1 or a (start, end) tuple/list')
  # An empty window would otherwise end in a division by zero when averaging.
  if end_t <= start_t:
    raise ValueError(f'empty timestep window: ({start_t}, {end_t})')
  return start_t, end_t


def _step_token_map(attn_store, prompt, res, t):
  """Sum over `res` of the step-`t` "up" cross-attention maps, each resized to 64x64.

  Each `aggregate_attention` result is permuted (2, 0, 1) for the resize and back
  (1, 2, 0), so the returned tensor keeps the original axis order. Its last axis is
  indexed by token position in get_mask.
  """
  return sum(
      TF.resize(
          aggregate_attention([prompt], attn_store, r, from_where=["up"], is_cross=True, select=0, step=t).permute(2, 0, 1),
          64, antialias=True,
      ).permute(1, 2, 0)
      for r in res
  )


def get_mask(attn_store, prompt, token, thresh=0.1, res=None, timestep=-1, num_steps=50):
  """Return the time-averaged soft cross-attention mask of `token` and its binarized version.

  For every step in the requested window, the "up"-block cross-attention maps for each
  resolution in `res` are resized to 64x64 and summed. The channel of the first index
  returned by `Phrase2idx(prompt, token)` is taken, and the per-step maps are averaged.

  Args:
    attn_store: attention store passed to `aggregate_attention`.
    prompt: prompt the attention was recorded for.
    token: word/phrase whose mask is returned.
    thresh: threshold for `SelfGuidanceEdits._attn_diff_norm(..., hard=True)`.
    res: attention resolutions to combine; defaults to [16, 32].
    timestep: `(start, end)` tuple or list (end exclusive), a single int step, or -1 for
      steps `0 .. num_steps - 1`.
    num_steps: number of steps covered by ``timestep=-1`` (default 50).

  Returns:
    `(soft_mask, hard_mask)`.

  Raises:
    TypeError: if `timestep` has an unsupported form.
    ValueError: if the step window is empty.
  """
  if res is None:
    res = [16, 32]
  start_t, end_t = _timestep_window(timestep, num_steps)

  obj_indice = Phrase2idx(prompt, token)[0][0]
  soft_mask = sum(
      _step_token_map(attn_store, prompt, res, t)[..., obj_indice]
      for t in range(start_t, end_t)
  ) / (end_t - start_t)
  hard_mask = SelfGuidanceEdits._attn_diff_norm(soft_mask, hard=True, thresh=thresh)

  return soft_mask, hard_mask

def get_unique_path(path):
  """Return `path` if it does not exist, otherwise the first free '_copy'-suffixed variant.

  - Existing directory: '_copy' is appended repeatedly ('out' -> 'out_copy' ->
    'out_copy_copy' ...). Trailing separators are ignored.
  - Any other existing path (e.g. a file): '_copy' goes before the extension
    ('img.png' -> 'img_copy.png'). Extension-less names get it appended ('Makefile' ->
    'Makefile_copy').
  - Otherwise: `path` is returned unchanged.

  Nothing is created on disk, so the name can still be taken by a concurrent writer.
  """
  if not os.path.exists(path):
    return path
  if os.path.isdir(path):
    root, ext = path.rstrip(os.sep) or path, ''
  else:
    # splitext only inspects the final component, so dots in parent directories and
    # extension-less file names are handled correctly.
    root, ext = os.path.splitext(path)
  candidate = path
  while os.path.exists(candidate):
    root += '_copy'
    candidate = root + ext
  return candidate

def _single_pixel_mask_512(shape, row, col, thresh):
  """Build a one-hot map of size `shape` at (row, col), upsample it to 512 and binarize it."""
  one_hot = torch.zeros(shape)
  one_hot[row, col] = 1.
  upsampled = TF.resize(one_hot.unsqueeze(0), 512, antialias=True).squeeze(0)
  return SelfGuidanceEdits._attn_diff_norm(upsampled, hard=True, thresh=thresh)


def _paint_red(pil_image, mask):
  """Return a copy of `pil_image` with the pixels where `mask > 0` set to (255, 0, 0)."""
  pixels = np.array(pil_image)  # np.array copies, so the source image is left untouched
  pixels[(mask > 0).cpu().numpy()] = (255, 0, 0)
  return Image.fromarray(pixels)


def vis_feature_mapping(outpath, mapping_vis_step=-1, num_steps=50):
  """Save images showing which reference pixel each base-mask pixel is matched to.

  Reads module-level state set by the main loop: `hard_mask_ref`, `processed_ref_aux_list`,
  `processed_base_aux`, `out_ref`, `out_base`, `ref_thresh`, `base_thresh` and the
  'hard_mask_base' entry of `Globals`.

  Both object masks are resized to 32x32 and binarized. For each of the first `num_steps`
  steps, the 'up_blocks.2.attentions.2' features under 'feats_up_2_2' are gathered at the
  pixels inside the masks. `remd_loss_m2(..., return_mat=True)` then gives a matrix `CX_M`,
  computed from the step-averaged features (`mapping_vis_step == -1`) or from step
  `mapping_vis_step`.

  `CX_M.min(0)` picks a row index for every column. Column i is treated as the i-th base-mask
  pixel and the row as a reference-mask pixel (assumes CX_M is
  (num_ref_pixels, num_base_pixels) — TODO confirm against remd_loss_m2). Each pair is
  painted red on copies of `out_ref[0]` / `out_base[0]` and saved side by side as `<i>.png`
  in a fresh 'feature_mapping_step_<step>' folder.

  Args:
    outpath: parent output directory.
    mapping_vis_step: step whose matrix is visualized, or -1 for the step-averaged features.
    num_steps: number of recorded steps to read (default 50).
  """
  feature_mapping_path = get_unique_path(os.path.join(outpath, f'feature_mapping_step_{mapping_vis_step}'))
  os.makedirs(feature_mapping_path)

  # The 32x32 object masks do not depend on the step, so build them once.
  mask_ref_s_32 = TF.resize(hard_mask_ref.unsqueeze(0), 32, antialias=True).squeeze(0)
  hard_mask_ref_s_32 = SelfGuidanceEdits._attn_diff_norm(mask_ref_s_32, hard=True, thresh=ref_thresh)
  mask_base_s_32 = TF.resize(Globals.get('hard_mask_base').unsqueeze(0), 32, antialias=True).squeeze(0)
  hard_mask_base_s_32 = SelfGuidanceEdits._attn_diff_norm(mask_base_s_32, hard=True, thresh=base_thresh)
  ref_in_mask = hard_mask_ref_s_32 > 0
  base_in_mask = hard_mask_base_s_32 > 0

  ref_feats = processed_ref_aux_list[0]['feats_up_2_2']['up_blocks.2.attentions.2']
  base_feats = processed_base_aux['feats_up_2_2']['up_blocks.2.attentions.2']
  ref_feature_list = [ref_feats[step][..., ref_in_mask].permute(0, 2, 1).squeeze(0).cuda() for step in range(num_steps)]
  base_feature_list = [base_feats[step][..., base_in_mask].permute(0, 2, 1).squeeze(0).cuda() for step in range(num_steps)]

  if mapping_vis_step == -1:
    ref_feature_mean = sum(ref_feature_list) / len(ref_feature_list)
    base_feature_mean = sum(base_feature_list) / len(base_feature_list)
    CX_M = remd_loss_m2(ref_feature_mean, base_feature_mean, splits=[ref_feature_mean.shape[-1]], return_mat=True)
  else:
    # Only the requested step's matrix is needed, so compute just that one.
    ref_feature = ref_feature_list[mapping_vis_step]
    base_feature = base_feature_list[mapping_vis_step]
    CX_M = remd_loss_m2(ref_feature, base_feature, splits=[base_feature.shape[-1]], return_mat=True)

  match_inds = CX_M.min(0).indices
  ref_rows, ref_cols = torch.where(ref_in_mask)
  base_rows, base_cols = torch.where(base_in_mask)

  # NOTE(review): pairs are written in base-mask scan order; the match distances are not
  # used for ordering. Confirm whether a best-match-first order was intended.
  for i in range(match_inds.shape[0]):
    j = match_inds[i]
    hard_ref_map_512 = _single_pixel_mask_512(hard_mask_ref_s_32.shape, ref_rows[j], ref_cols[j], ref_thresh)
    hard_base_map_512 = _single_pixel_mask_512(hard_mask_base_s_32.shape, base_rows[i], base_cols[i], base_thresh)

    pil_image_ref = _paint_red(out_ref[0], hard_ref_map_512)
    pil_image_base = _paint_red(out_base[0], hard_base_map_512)

    concat_images([pil_image_ref, pil_image_base]).save(os.path.join(feature_mapping_path, f'{i}.png'))


def generate_img(seed, prompt, x_t=None, uncond_embeddings=None):
  """Generate the self-guided image(s) for `prompt` with `base` and save the results.

  Three guidance terms are passed to `base` through `sg_edits`:
    * two feature terms, keyed by ('attn_up_2_1_0', 'feats_up_2_1') and
      ('attn_up_2_2_0', 'feats_up_2_2'), on `class_token`. They use `sg_fn` with weight
      `sg_weight`, target `processed_ref_aux_list[0]`, 'base_aux' `processed_base_aux`,
      'timestep' [feats_start, feats_end] and thresh `global_thresh`;
    * one 'attn' shape term on `class_token` plus the current `bg_word` words. It uses
      `shape_fn` with weight 100, target `processed_templ_aux` and 'timestep'
      [shape_start, shape_end].

  Written under the module-level `outpath`:
    * 'personalize_<seed>.png' — the generated images;
    * 'result_<seed>.png' — reference, base, template and generated images side by side;
    * soft/hard masks of `class_token` in a fresh 'attn_personalize_<seed>' folder.

  Args:
    seed: seed for the torch.Generator used in sampling.
    prompt: prompt to generate (repeated `n_img` times).
    x_t: optional initial latents, forwarded as `latents`.
    uncond_embeddings: optional unconditional embeddings forwarded to `base`.
  """
  def feature_edit():
    # A fresh dict per layer, so the two feature terms never share mutable state.
    # (A 'gamma': 4 kwarg was previously left disabled here.)
    return {
        'words': [class_token],
        'fn': sg_fn,
        'weight': sg_weight,
        'tgt': processed_ref_aux_list[0],
        'base_aux': processed_base_aux,
        'timestep': [feats_start, feats_end],
        'kwargs': {'thresh': global_thresh},
    }

  sg_edits = {
      ('attn_up_2_1_0', 'feats_up_2_1'): [feature_edit()],
      ('attn_up_2_2_0', 'feats_up_2_2'): [feature_edit()],
      'attn': [{
          'words': [class_token] + bg_word,
          'fn': shape_fn,
          'kwargs': {},  # 'norm': 2 and 'focal': True were previously left disabled
          'weight': 100,
          'timestep': [shape_start, shape_end],
          'tgt': processed_templ_aux,
      }],
  }

  generator = torch.Generator().manual_seed(seed)

  attn_path = get_unique_path(os.path.join(outpath, f'attn_personalize_{seed}'))
  Globals.set('attn_path', attn_path)
  os.makedirs(attn_path, exist_ok=True)
  for key, value in (('save_aux', False), ('new_timestep', True), ('DEBUG', True), ('step', 0)):
    Globals.set(key, value)
  attn_store.reset()

  out_after = base(
      prompt=[prompt] * n_img,
      sg_grad_wt=sg_grad_wt,
      sg_edits=sg_edits,
      num_inference_steps=num_inference_steps,
      sg_loss_rescale=sg_loss_rescale,
      debug=False,
      feats_start=feats_start,
      feats_end=feats_end,
      generator=generator,
      latents=x_t,
      uncond_embeddings=uncond_embeddings,
  ).images

  if Globals.get('save_aux'):
    edit_aux = base.get_sg_aux(get_whole=True)  # captured but not used further here
  personalized_grid = concat_images(out_after)
  personalized_path = get_unique_path(os.path.join(outpath, 'personalize_' + f'{seed}.png'))
  personalized_grid.save(personalized_path)

  comparison_images = out_ref_list + out_base + out_templ + out_after
  pil_path = get_unique_path(os.path.join(outpath, f'result_{seed}.png'))
  concat_images(comparison_images).save(pil_path)

  print(f"Saved at {pil_path}")

  if Globals.get('DEBUG'):
    save_mask(base, prompt, class_token, attn_path, thresh=base_thresh, res=[16, 32])

  soft_mask_after, hard_mask_after = get_mask(attn_store, prompt, class_token, thresh=base_thresh, res=[16, 32], timestep=(15, 50))
  for mask, file_name in ((soft_mask_after, '_soft_mask.png'), (hard_mask_after, '_hard_mask.png')):
    save_mask_as_img(mask, os.path.join(attn_path, file_name))


#####################################################################################################################
#####################################################################################################################

# ---- Generation / guidance hyper-parameters (read as module-level globals by the functions above) ----
# Lines tagged `# @param` are Colab form fields; keep them on a single line each.

# Steps exported by save_mask(): every step in [0, 10), every 2nd step in [10, 20),
# then 19, 29, 39 and 49.
debug_steps = list(range(0, 10)) + list(range(10, 20, 2)) + list(range(19, 50, 10))

# Denoising steps for every pipeline call.
num_inference_steps = 50 # @param {type:"integer"}
# Forwarded to `base(...)` in generate_img as `sg_loss_rescale` / `sg_grad_wt`.
sg_loss_rescale = 1000. # @param {type: "number"}
sg_grad_wt = 4. # @param {type: "number"}
# 'timestep' windows of the feature terms and of the attention-shape term in generate_img.
feats_start, feats_end = (10, 20)
shape_start, shape_end = (0, 10)
# Union of both windows; not referenced elsewhere in this script.
global_start = min(shape_start, feats_start)
global_end = max(shape_end, feats_end)
  

# Number of images per prompt (each prompt is repeated n_img times).
n_img=1  # @param {type:"integer"}
# Thresholds for SelfGuidanceEdits._attn_diff_norm: `global_thresh` is the feature terms'
# 'thresh' kwarg; `ref_thresh` binarizes the reference mask; `base_thresh` the
# base/template/output masks.
global_thresh = 0.3
ref_thresh = 0.1
base_thresh = 0.1
sg_fn = SelfGuidanceEdits.remd2_mse  # loss of the two feature terms
shape_fn = SelfGuidanceEdits.shape  # loss of the attention-shape term
sg_weight = 10  # weight of each feature term (the shape term uses a fixed 100)
feature_mapping_vis = False  # if True, vis_feature_mapping() runs for every base seed
mapping_vis_step = 10  # step visualized by vis_feature_mapping (-1 = mean over steps)
# Only appears in the output folder name; no loop in this script iterates over it.
num_iteration = 5

# One base/template/personalized image set is generated per seed 0 .. args.num_images - 1.
base_seeds = [*range(args.num_images)]

# Personalization settings from the command line. `unique_token` is presumably the
# identifier word; `class_token` is the word whose attention masks are extracted.
unique_token, class_token, subject_name = args.unique_token, args.class_token, args.subject_name

# Prompt of the reference image, also used for null-text inversion.
prompt_ref = f"a photo of a {class_token}"

# (prompt template, background words) pairs. Templates are formatted with
# (unique_token, class_token). The background words are appended to the shape term's
# 'words' in generate_img; [''] means "no background word" (turned into [] by the main loop).
# Keeping each template next to its words guarantees the two derived lists stay aligned,
# so zip() in the main loop cannot shift words onto the wrong prompt or drop a prompt.
_PROMPT_SPECS = [
    ('a photo of {0} {1} swimming in water', ['water']),
    ('a {0} {1} in the jungle', ['jungle']),
    ('a {0} {1} in the snow', ['snow']),
    ('a {0} {1} on the beach', ['beach']),
    ('a {0} {1} on a cobblestone street', ['cobblestone', 'street']),
    ('a {0} {1} on top of pink fabric', ['pink', 'fabric']),
    ('a {0} {1} on top of a wooden floor', ['wooden', 'floor']),
    ('a {0} {1} with a city in the background', ['city']),
    ('a {0} {1} with a mountain in the background', ['mountain']),
    ('a {0} {1} with a blue house in the background', ['blue', 'house']),
    ('a {0} {1} on top of a purple rug in a forest', ['purple', 'rug', 'forest']),
    ('a {0} {1} with a wheat field in the background', ['wheat', 'field']),
    ('a {0} {1} with a tree and autumn leaves in the background', ['tree', 'autumn', 'leaves']),
    ('a {0} {1} with the Eiffel Tower in the background', ['Eiffel', 'Tower']),
    ('a {0} {1} floating on top of water', ['water']),
    ('a {0} {1} floating in an ocean of milk', ['ocean', 'milk']),
    ('a {0} {1} on top of green grass with sunflowers around it', ['green', 'grass', 'sunflowers']),
    ('a {0} {1} on top of a mirror', ['mirror']),
    ('a {0} {1} on top of the sidewalk in a crowded street', ['sidewalk', 'street']),
    ('a {0} {1} on top of a dirt road', ['dirt', 'road']),
    ('a {0} {1} on top of a white rug', ['white', 'rug']),
    ('a red {0} {1}', ['']),
    ('a purple {0} {1}', ['']),
    ('a shiny {0} {1}', ['']),
    ('a wet {0} {1}', ['']),
    ('a cube shaped {0} {1}', ['']),
]
prompt_personalize_list = [template.format(unique_token, class_token) for template, _ in _PROMPT_SPECS]
bg_word_list = [words for _, words in _PROMPT_SPECS]
attn_words_list = [[class_token]]

ref_path = args.ref_dir

# Null-text inversion results for the reference image are cached per
# `args.pm_pretrained_model_name_or_path` ('/' replaced by '_') and per subject.
inversion_folder = os.path.join('inversion', args.pm_pretrained_model_name_or_path.replace('/', '_'))
os.makedirs(inversion_folder, exist_ok=True)
x_t_path, uncond_embeddings_path = (
    os.path.join(inversion_folder, f"{subject_name}_{suffix}.pt")
    for suffix in ("x_t", "uncond_embeddings")
)

# Run the inversion only when either cached tensor is missing.
_inversion_cached = all(os.path.exists(_p) for _p in (x_t_path, uncond_embeddings_path))
if not _inversion_cached:
  for _key, _value in (('save_aux', False), ('new_timestep', True), ('DEBUG', False), ('step', 0)):
    Globals.set(_key, _value)
  image = load_512(ref_path)
  pil_image = view_images(image)
  null_inversion = NullInversion(base)
  (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(ref_path, prompt_ref, offsets=(0,0,0,0), verbose=True, DDIM=False)

  for _tensor, _dst in ((x_t, x_t_path), (uncond_embeddings, uncond_embeddings_path)):
    torch.save(_tensor, _dst)

# Register the pipeline's hooks before the generation loop.
base.register_hook()

# Batch item of every captured aux tensor that is kept as a guidance target.
aux_idx = 0

def _select_aux_sample(aux):
  """Keep batch item `aux_idx` of every leaf of a captured aux dict.

  Each leaf `x` (assumed to be a tensor) becomes `x[aux_idx:aux_idx+1]`, repeated `n_img`
  times along dim 0 and moved to the CPU. `aux_idx` and `n_img` are module-level settings.
  """
  return {k: torch.utils._pytree.tree_map(lambda x: x[aux_idx:aux_idx+1].repeat_interleave(n_img, 0).cpu(), v)
          for k, v in aux.items()}


# Main loop. For every (prompt, background words) pair:
#   1. regenerate the reference from its cached inversion;
#   2. then, for each base seed, generate a base image with `base` and a template with `sd_pipe`;
#   3. run the self-guided generation.
# The module-level names assigned here (outpath, bg_word, out_ref, out_ref_list, out_base,
# out_templ, hard_mask_ref, processed_*) are read as globals by generate_img() and
# vis_feature_mapping().
for prompt_personalize, bg_word in zip(prompt_personalize_list, bg_word_list):
  # NOTE(review): obj_indice is not used below.
  obj_indices = Phrase2idx(prompt_personalize, class_token)
  obj_indice = obj_indices[0][0]
  if bg_word == ['']:
    bg_word = []
  print(f'\nbg_indice:{bg_word}\n')
  custom_obj_str = subject_name + str(sg_weight)
  outpath = os.path.join(f'templ_output_2-12_test_{sg_fn.__name__}_{feats_start}_{feats_end}_{global_thresh}_{shape_fn.__name__}_{shape_start}_{shape_end}_global_iter_{num_iteration}', custom_obj_str, f'{prompt_personalize}')
  print(f'outpath: {outpath}')

  # ---- Reference: run `base` on `prompt_ref` from the cached inversion latents ----
  ref_aux_list = []
  out_ref_list = []
  x_t_path = os.path.join(inversion_folder, f"{subject_name}_x_t.pt")
  uncond_embeddings_path = os.path.join(inversion_folder, f"{subject_name}_uncond_embeddings.pt")
  prompt_text_ids = base.tokenizer(prompt_ref, return_tensors='np')['input_ids'][0]
  Globals.set('prompt_text_ids', prompt_text_ids)

  x_t = torch.load(x_t_path)
  uncond_embeddings = torch.load(uncond_embeddings_path)

  generator = torch.Generator().manual_seed(0)

  Globals.set('save_aux', True)
  Globals.set('new_timestep', True)
  Globals.set('DEBUG', False)
  Globals.set('step', 0)
  attn_store.reset()
  attn_path = os.path.join(outpath, f"attn_{subject_name}")
  os.makedirs(attn_path, exist_ok=True)
  out_ref = base(prompt=[prompt_ref]*n_img, num_inference_steps=num_inference_steps, generator=generator, latents=x_t, uncond_embeddings=uncond_embeddings, debug=False).images
  if Globals.get('save_aux'): ref_aux_list.append(base.get_sg_aux(get_whole=True))
  pil_image_ref = concat_images(out_ref)
  out_ref_list += out_ref
  ref_pil_path = os.path.join(outpath, 'ref_'+f'{prompt_ref}_{0}.png')
  pil_image_ref.save(ref_pil_path)
  if Globals.get('DEBUG'):
    save_mask(base, prompt_ref, class_token, attn_path, thresh=ref_thresh, res=[16,32])
  soft_mask_ref, hard_mask_ref = get_mask(attn_store, prompt_ref, class_token, thresh=ref_thresh, res=[16,32], timestep=(15,50))
  save_mask_as_img(soft_mask_ref, os.path.join(attn_path, '_soft_mask.png'))
  save_mask_as_img(hard_mask_ref, os.path.join(attn_path, '_hard_mask.png'))
  # No blanket try/except: a failure here must surface immediately instead of leaving a
  # previous prompt's targets in place for generate_img().
  processed_ref_aux_list = [_select_aux_sample(ref_aux) for ref_aux in ref_aux_list]

  # ---- Per seed: base image, SD template, then the self-guided generation ----
  for base_seed in base_seeds:
    generator = torch.Generator().manual_seed(base_seed)
    attn_path = os.path.join(outpath, f'attn_base_{base_seed}')
    Globals.set('attn_path', attn_path)
    os.makedirs(attn_path, exist_ok=True)
    Globals.set('save_aux', True)
    Globals.set('new_timestep', True)
    Globals.set('DEBUG', False)
    Globals.set('step', 0)
    attn_store.reset()

    out_base = base(prompt=[prompt_personalize]*n_img, num_inference_steps=num_inference_steps, generator=generator, debug=False).images

    if Globals.get('save_aux'): base_aux = base.get_sg_aux(get_whole=True)
    pil_image_base = concat_images(out_base)
    base_pil_path = os.path.join(outpath, f'base_{prompt_personalize}_{base_seed}.png')
    pil_image_base.save(base_pil_path)
    if Globals.get('DEBUG'):
      save_mask(base, prompt_personalize, class_token, attn_path, thresh=base_thresh, res=[16,32])
    soft_mask_base, hard_mask_base = get_mask(attn_store, prompt_personalize, class_token, thresh=base_thresh, res=[16,32], timestep=(15,50))
    Globals.set('hard_mask_base', hard_mask_base)
    save_mask_as_img(soft_mask_base, os.path.join(attn_path, '_soft_mask.png'))
    save_mask_as_img(hard_mask_base, os.path.join(attn_path, '_hard_mask.png'))
    processed_base_aux = _select_aux_sample(base_aux)
    if feature_mapping_vis:
      vis_feature_mapping(outpath, mapping_vis_step=mapping_vis_step)

    # sd as template
    generator = torch.Generator().manual_seed(base_seed)
    attn_path = os.path.join(outpath, f'attn_templ_{base_seed}')
    Globals.set('attn_path', attn_path)
    os.makedirs(attn_path, exist_ok=True)
    Globals.set('save_aux', True)
    Globals.set('new_timestep', True)
    Globals.set('DEBUG', False)
    Globals.set('step', 0)
    attn_store.reset()

    out_templ = sd_pipe(prompt=[prompt_personalize]*n_img, num_inference_steps=num_inference_steps, generator=generator, debug=False).images

    if Globals.get('save_aux'): templ_aux = sd_pipe.get_sg_aux(get_whole=True)
    pil_image_templ = concat_images(out_templ)
    templ_pil_path = os.path.join(outpath, f'templ_{prompt_personalize}_{base_seed}.png')
    pil_image_templ.save(templ_pil_path)
    if Globals.get('DEBUG'):
      save_mask(sd_pipe, prompt_personalize, class_token, attn_path, thresh=base_thresh, res=[16,32])
    soft_mask_templ, hard_mask_templ = get_mask(attn_store, prompt_personalize, class_token, thresh=base_thresh, res=[16,32], timestep=(15,50))
    save_mask_as_img(soft_mask_templ, os.path.join(attn_path, '_soft_mask.png'))
    save_mask_as_img(hard_mask_templ, os.path.join(attn_path, '_hard_mask.png'))
    processed_templ_aux = _select_aux_sample(templ_aux)

    generate_img(base_seed, prompt_personalize)
